import jax
from jax import random
from jax import numpy as jnp
from PyCmpltrtok.common import sep

# Fixed seed so the randomly sampled input below is reproducible across runs.
key = random.PRNGKey(seed=0)


def f(x):
    """Return half of the elementwise square of ``x``, i.e. ``x**2 / 2``.

    The derivative of this function is ``x`` itself. The directional
    derivative along ``v`` is therefore ``x * v``.
    """
    squared = jnp.square(x)
    return squared / 2.0


# A random primal point x and an all-ones vector v. v is used as the tangent
# for forward mode and as the cotangent for reverse mode.
x = random.normal(key, (5,))
v = jnp.ones(5)
print('x', x)
print('v', v)
print('f(x)', f(x))

# Forward mode: jax.jvp pushes the tangent v through f. It returns the primal
# output f(x) together with the Jacobian-vector product J(x) @ v.
sep('jvp')
fx, dfv = jax.jvp(f, (x, ), (v, ))
print('f(x)', fx)
print('dfv', dfv)

# Reverse mode: jax.vjp returns f(x) and a pullback function. The pullback
# computes a vector-Jacobian product (v @ J(x)), not a JVP. It returns one
# cotangent per positional primal argument, so the 1-tuple is unpacked here.
# f acts elementwise, so J(x) is diagonal and v @ J(x) equals J(x) @ v above.
sep('vjp')
fx, vjp_fun = jax.vjp(f, x)
(vjp_v,) = vjp_fun(v)
print('x', x)
print('fx', fx)
print('dfv', vjp_v)
